{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR-10 Classification with Data-efficient Image Transformers (DeiT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, random_split, DataLoader\n",
    "\n",
    "from torchvision import datasets, transforms, models\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset, random_split, DataLoader\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "from torchsummary import summary\n",
    "\n",
    "import spacy\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "from PIL import Image\n",
    "import glob\n",
    "from IPython.display import display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "LR = 5e-5\n",
    "NUM_EPOCHES = 25"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean, std = (0.5,), (0.5,)\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Normalize(mean, std)\n",
    "                              ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "trainset = datasets.CIFAR10('../data/CIFAR10/', download=True, train=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "testset = datasets.CIFAR10('../data/CIFAR10/', download=True, train=False, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the Pre-trained Teacher Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://github.com/UdbhavPrasad072300/Transformer-Implementations/blob/main/pre-train/VGG16_CIFAR10.ipynb\n",
    "# The checkpoint is a fully pickled nn.Module: torch.load unpickles it, which can run arbitrary code,\n",
    "# so only load checkpoints you trust. (On PyTorch >= 2.6, torch.load defaults to weights_only=True,\n",
    "# which rejects pickled modules; pass weights_only=False there.)\n",
    "# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.\n",
    "teacher_model = torch.load(\"../trained_models/vgg16_cifar10.pth\", map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n"
     ]
    }
   ],
   "source": [
    "teacher_model.preprocess_flag = False\n",
    "print(teacher_model.preprocess_flag)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DeiT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_package.models import DeiT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeiT(\n",
       "  (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "  (embeddings): Linear(in_features=48, out_features=512, bias=True)\n",
       "  (teacher_model): VGG16_classifier(\n",
       "    (vgg16): VGG(\n",
       "      (features): Sequential(\n",
       "        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (1): ReLU(inplace=True)\n",
       "        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (3): ReLU(inplace=True)\n",
       "        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (6): ReLU(inplace=True)\n",
       "        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (8): ReLU(inplace=True)\n",
       "        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (11): ReLU(inplace=True)\n",
       "        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (13): ReLU(inplace=True)\n",
       "        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (15): ReLU(inplace=True)\n",
       "        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (18): ReLU(inplace=True)\n",
       "        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (20): ReLU(inplace=True)\n",
       "        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (22): ReLU(inplace=True)\n",
       "        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (25): ReLU(inplace=True)\n",
       "        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (27): ReLU(inplace=True)\n",
       "        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "        (29): ReLU(inplace=True)\n",
       "        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      )\n",
       "      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
       "      (classifier): Sequential(\n",
       "        (0): Linear(in_features=25088, out_features=2048, bias=True)\n",
       "        (1): ReLU()\n",
       "        (2): Dropout(p=0.3, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=1024, bias=True)\n",
       "        (4): ReLU()\n",
       "        (5): Dropout(p=0.3, inplace=False)\n",
       "        (6): Linear(in_features=1024, out_features=512, bias=True)\n",
       "        (7): ReLU()\n",
       "        (8): Dropout(p=0.3, inplace=False)\n",
       "        (9): Linear(in_features=512, out_features=10, bias=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (encoders): ModuleList(\n",
       "    (0): VisionEncoder(\n",
       "      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (attention): MultiHeadAttention(\n",
       "        (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "        (Q): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (K): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (V): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (linear): Linear(in_features=512, out_features=512, bias=True)\n",
       "      )\n",
       "      (mlp): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "        (1): GELU()\n",
       "        (2): Dropout(p=0.2, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "        (4): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (1): VisionEncoder(\n",
       "      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (attention): MultiHeadAttention(\n",
       "        (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "        (Q): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (K): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (V): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (linear): Linear(in_features=512, out_features=512, bias=True)\n",
       "      )\n",
       "      (mlp): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "        (1): GELU()\n",
       "        (2): Dropout(p=0.2, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "        (4): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (2): VisionEncoder(\n",
       "      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (attention): MultiHeadAttention(\n",
       "        (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "        (Q): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (K): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (V): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (linear): Linear(in_features=512, out_features=512, bias=True)\n",
       "      )\n",
       "      (mlp): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "        (1): GELU()\n",
       "        (2): Dropout(p=0.2, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "        (4): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (3): VisionEncoder(\n",
       "      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      (attention): MultiHeadAttention(\n",
       "        (dropout_layer): Dropout(p=0.2, inplace=False)\n",
       "        (Q): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (K): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (V): Linear(in_features=512, out_features=512, bias=True)\n",
       "        (linear): Linear(in_features=512, out_features=512, bias=True)\n",
       "      )\n",
       "      (mlp): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "        (1): GELU()\n",
       "        (2): Dropout(p=0.2, inplace=False)\n",
       "        (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "        (4): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=512, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "image_size = 32\n",
    "channel_size = 3\n",
    "patch_size = 4\n",
    "embed_size = 512\n",
    "num_heads = 8\n",
    "classes = 10\n",
    "num_layers = 4\n",
    "hidden_size = 512\n",
    "dropout = 0.2\n",
    "\n",
    "model = DeiT(image_size=image_size, \n",
    "             channel_size=channel_size, \n",
    "             patch_size=patch_size, \n",
    "             embed_size=embed_size, \n",
    "             num_heads=num_heads, \n",
    "             classes=classes, \n",
    "             num_layers=num_layers,\n",
    "             hidden_size=hidden_size,\n",
    "             teacher_model=teacher_model,\n",
    "             dropout=dropout\n",
    "            ).to(device)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input Image Dimensions: torch.Size([32, 3, 32, 32])\n",
      "Label Dimensions: torch.Size([32])\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Student Output Dimensions: torch.Size([32, 10])\n",
      "Teacher Output Dimensions: torch.Size([32, 10])\n"
     ]
    }
   ],
   "source": [
    "for img, label in trainloader:\n",
    "    img = img.to(device)\n",
    "    label = label.to(device)\n",
    "    \n",
    "    print(\"Input Image Dimensions: {}\".format(img.size()))\n",
    "    print(\"Label Dimensions: {}\".format(label.size()))\n",
    "    print(\"-\"*100)\n",
    "    \n",
    "    out, teacher_out = model(img)\n",
    "    \n",
    "    print(\"Student Output Dimensions: {}\".format(out.size()))\n",
    "    print(\"Teacher Output Dimensions: {}\".format(teacher_out.size()))\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training with Hard Label Distillation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_package.loss_functions.loss import Hard_Distillation_Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = Hard_Distillation_Loss()\n",
    "\n",
    "# The teacher is a registered submodule of `model`, so model.parameters() also yields its weights.\n",
    "# DeiT distils from a fixed teacher: freeze it and optimise only the student's parameters.\n",
    "for param in model.teacher_model.parameters():\n",
    "    param.requires_grad_(False)\n",
    "student_params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.Adam(student_params, lr=LR)\n",
    "\n",
    "# NOTE: scheduler.step() is not called in the training loop below, so the LR stays at LR.\n",
    "# If it is ever stepped, StepLR's default gamma=0.1 divides the LR by 10 every epoch.\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------------------------\n",
      "Epoch: 1 Train mean loss: 1644.12019843\n",
      "       Train Accuracy%:  32.28 == 16140 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 2 Train mean loss: 1412.12571043\n",
      "       Train Accuracy%:  43.754 == 21877 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 3 Train mean loss: 1342.61206454\n",
      "       Train Accuracy%:  46.948 == 23474 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 4 Train mean loss: 1295.64139086\n",
      "       Train Accuracy%:  49.306 == 24653 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 5 Train mean loss: 1259.07025161\n",
      "       Train Accuracy%:  51.146 == 25573 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 6 Train mean loss: 1228.55797020\n",
      "       Train Accuracy%:  52.444 == 26222 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 7 Train mean loss: 1201.34935898\n",
      "       Train Accuracy%:  53.882 == 26941 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 8 Train mean loss: 1173.09531677\n",
      "       Train Accuracy%:  55.096 == 27548 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 9 Train mean loss: 1143.60485944\n",
      "       Train Accuracy%:  56.618 == 28309 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 10 Train mean loss: 1117.18866742\n",
      "       Train Accuracy%:  57.664 == 28832 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 11 Train mean loss: 1090.56170151\n",
      "       Train Accuracy%:  58.954 == 29477 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 12 Train mean loss: 1064.91038817\n",
      "       Train Accuracy%:  60.15 == 30075 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 13 Train mean loss: 1038.63289067\n",
      "       Train Accuracy%:  61.414 == 30707 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 14 Train mean loss: 1016.84226078\n",
      "       Train Accuracy%:  62.382 == 31191 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 15 Train mean loss: 991.65151086\n",
      "       Train Accuracy%:  63.84 == 31920 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 16 Train mean loss: 961.17885591\n",
      "       Train Accuracy%:  64.846 == 32423 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 17 Train mean loss: 935.56182054\n",
      "       Train Accuracy%:  66.278 == 33139 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 18 Train mean loss: 911.77070428\n",
      "       Train Accuracy%:  67.166 == 33583 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 19 Train mean loss: 879.42079553\n",
      "       Train Accuracy%:  68.576 == 34288 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 20 Train mean loss: 855.68634495\n",
      "       Train Accuracy%:  69.844 == 34922 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 21 Train mean loss: 829.04679237\n",
      "       Train Accuracy%:  71.142 == 35571 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 22 Train mean loss: 800.25847559\n",
      "       Train Accuracy%:  72.304 == 36152 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 23 Train mean loss: 770.28103490\n",
      "       Train Accuracy%:  73.846 == 36923 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 24 Train mean loss: 739.04231864\n",
      "       Train Accuracy%:  75.072 == 37536 / 50000\n",
      "-------------------------------------------------\n",
      "-------------------------------------------------\n",
      "Epoch: 25 Train mean loss: 712.79009889\n",
      "       Train Accuracy%:  76.214 == 38107 / 50000\n",
      "-------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "loss_hist = {}\n",
    "loss_hist[\"train accuracy\"] = []\n",
    "loss_hist[\"train loss\"] = []\n",
    "\n",
    "for epoch in range(1, NUM_EPOCHES+1):\n",
    "    model.train()\n",
    "    # model.train() also switches the teacher submodule to train mode, which enables the\n",
    "    # Dropout layers in its VGG classifier and randomly perturbs the teacher's hard labels.\n",
    "    # Keep the fixed teacher deterministic while the student trains.\n",
    "    model.teacher_model.eval()\n",
    "    \n",
    "    epoch_train_loss = 0\n",
    "    \n",
    "    y_true_train = []\n",
    "    y_pred_train = []\n",
    "    \n",
    "    for batch_idx, (img, labels) in enumerate(trainloader):\n",
    "        img = img.to(device)\n",
    "        labels = labels.to(device)\n",
    "        \n",
    "        preds, teacher_preds = model(img)\n",
    "        \n",
    "        loss = criterion(teacher_preds, preds, labels)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        y_pred_train.extend(preds.detach().argmax(dim=-1).tolist())\n",
    "        y_true_train.extend(labels.detach().tolist())\n",
    "        \n",
    "        epoch_train_loss += loss.item()\n",
    "    \n",
    "    # The loop accumulates a sum over batches; divide so the logged value is a true mean.\n",
    "    epoch_train_loss /= len(trainloader)\n",
    "    loss_hist[\"train loss\"].append(epoch_train_loss)\n",
    "    \n",
    "    total_correct = sum(x == y for x, y in zip(y_pred_train, y_true_train))\n",
    "    total = len(y_pred_train)\n",
    "    accuracy = total_correct * 100 / total\n",
    "    \n",
    "    loss_hist[\"train accuracy\"].append(accuracy)\n",
    "    \n",
    "    print(\"-------------------------------------------------\")\n",
    "    print(\"Epoch: {} Train mean loss: {:.8f}\".format(epoch, epoch_train_loss))\n",
    "    print(\"       Train Accuracy%: \", accuracy, \"==\", total_correct, \"/\", total)\n",
    "    print(\"-------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAgM0lEQVR4nO3deXxV9Z3/8dcnG4EQ1oQ9IWGXXQwosogiTqW2UDe0HQWXau1mx19ntO1jOl2cqa0zHdppZ37iilVba62iXRTFKqASdhAQWbIRCNkxG1nvd/7Itc3PX4AEOTk397yfjwePe88hl3y+nNz7zvl+z/l+zTmHiIgEV4zfBYiIiL8UBCIiAacgEBEJOAWBiEjAKQhERAIuzu8COiIlJcVlZGT4XYaISLeybdu2Mudc6pm+rlsEQUZGBlu3bvW7DBGRbsXM8jvydeoaEhEJOAWBiEjAKQhERAJOQSAiEnAKAhGRgFMQiIgEnIJARCTgFAQiIhHoUEkN9/9hH5W1jZ5/r25xQ5mISBA0Nod4de9xns7OZ1NOBfGxxkWjBnL5xMGefl8FgYiIzwrK63hmcwHPbT1CeW0jI/r35J8+NZ7rLkgjNbmH599fQSAi4oPmlhDr9pfwdHYB6w+UEhtjLJwwiC9cNJJ5Y1KIibEuq0VBICLShY6dOMlvthzh2S0FFFc1MKRPIt+4fCzLZqYxtG9PX2pSEIiIeCwUcqw/WMpTmwp4Y38xDrhkXCr3Lx3JpeNTiYv197odBYGIiEdqG5p5fnshT7ydR05ZLSm9e3DXgtHcMDOdtAG9/C7vrxQEIiLn2JGKOp58N4/fbDlCdX0z09L68bMbpnPl5KEkxEXeVfsKAhGRc8A5x+bcCh5/O4+1+45jZlw5eQi3zs1kRnp/v8s7LQWBiMgnUN/Uwsu7jvH423nsK6qiX6947rxkNDddNJJh/fwZ/O0sBYGIyFkoqa7nqU0FPJOdT1lNI+MG9+ZHV09h6fTh9EyI9bu8TlEQiIh0QklVPSvXHeS5rUdoDjkuGz+IW+ZkMmfMQMy67tr/c0lBICLSAbUNzaxan8PDG3JobA5xw6w0bps7isyUJL9L+8QUBCIip9HUEuLZLUdY+fpBymoa+PTUofzjFePJiIIA+IiCQESkHc451u4r5sev7CentJZZGQN4+OYLOD/CrwA6GwoCEZGP2V5QyY/+9D5b8ioZnZrEwzdncfl5g7rtGMCZKAhERMJyy2p58NX9/Om946Qm9+DfPjeF67NG+D4FhNcUBCISeGU1DfzXuoM8nV1AQlwM/3D5OG6fl0lSj2B8RAajlSIi7ThUUs1Tmwr43bZCTja1cOOsNO5eOK5L1gCIJAoCEQmUxuYQa/cd56lNrauAJcTGsHjKEL62cCyjU3v7XZ4vFAQiEgjHTpzk15sL+M2WI5RWNzCif0/u/dQErs8awcDewToD+DgFgYhErVDIsfFQGb/alM+691vXAbh0/CBuumgk88elEtuFq4BFMgWBiESdytpGfretkKey88kvr2NgUgJ3XjKaz8+KrHUAIoWCQESixgfHq3l4Qw4v7TpGY3OImRn9uWfROD41eQg94rrXRHBdSUEgIt3aR+sAPLQ+hzf2l9AzPpbrLhjBTbNHMmFIH7/L6xYUBCLSLbWEHGv3Hueh9TnsPHKCAUkJ3LNoHDddNJL+SQl+l9eteBYEZjYeeLbNrlHAd4Enw/szgDzgeudcpVd1iEh0qW9q4fnthTy8Poe88jpGDuzFD5dO5roLRpAYr+6fs+FZEDjnPgCmA5hZLHAUeAG4D1jnnHvAzO4Lb9/rVR0iEh1O1DXy1KZ8nngnj7KaRqaO6Mt/f2EGfzdpiK7++YS6qmtoIXDYOZdvZkuABeH9q4E3URCIyCkcPXGSRzfk8pstBdQ1trBgfCp3zh/NRaMGRO0kcF2tq4LgBuDX4eeDnXNFAM65IjMb1N4LzOwO4A6A9PT0LilS
RCLHnqMf8siGHF7eXYQBn502jC/OH8V5QzUAfK55HgRmlgB8FvhWZ17nnFsFrALIyspyHpQmIhEmFHK8eaCEh9fn8m5OOUkJsay4OINb52YyvJssBN8ddcUZwZXAdudccXi72MyGhs8GhgIlXVCDiESw+qYWXtxxlEc25nKopIahfRP59uIJ3DArnT6J8X6XF/W6Ighu5G/dQgAvAcuBB8KPa7qgBhGJQBW1jfzq3Xx+tal1AHjSsD6sXDadT08dSnyUrwEQSTwNAjPrBSwC7myz+wHgt2Z2G1AAXOdlDSISeXJKa3h0Yy7Pby+kvinEZRMGcfu8TGaPGqgBYB94GgTOuTpg4Mf2ldN6FZGIBIhzji15laxan8O6/cXEx8ZwzYzh3DY3kzGDkv0uL9B0Z7GIeMo5x7uHy1n5+kE251UwICmBr102lptnjyQl4NM/RwoFgYh44uMBMKRPIj9YMonrs9J0B3CEURCIyDnVXgD8cMkkrp+ZphlAI5SCQETOCQVA96UgEJFPRF1A3Z+CQETOigIgeigIRKRTjlTUsXZfMX/YfYwdBScUAFFAQSAip+WcY//xatbuLebVvcfZV1QFwIQhyXz/s5NYNlMB0N0pCETk/9MScmwvqGTt3uO8ureYgoo6zOCC9P58Z/F5XDFpMCMHJvldppwjCgIRAaChuYV3DpWzdt9xXttXTFlNIwmxMVw8ZiB3LRjN5ecNJjVZN4BFIwWBSMAVVtbxy78c5uVdx6hpaKZ3jzgWjE/l7yYNYcH4VJI1+2fUUxCIBFTRhyf55V8O8eyWIxjG0vOHceWUoVw8eqCu+w8YBYFIwJRU1fPfbx7mmewCHI5lM9P48oIxDNPCL4GlIBAJiNLqBh566zC/2pRPS8hxXdYIvnLpGEb07+V3aeIzBYFIlKuobeSh9Yd58p18GppbuHrGCL5+2VjSByoApJWCQCRKnahr5OENOTzxdh4nm1pYMn04X184lswUXfYp/y8FgUiUKa6q5+nsAh7bmEttYzNXTR3G3QvHaPEXOSUFgUgUqK5v4pU9x1mz8xjvHC4j5GDxlCHcvXAc44coAOT0FAQi3VRjc4i3DpTy4s6jvL6vmIbmEOkDevHVS8ew9PzhjErt7XeJ0k0oCES6kVDIsa2gkhd3HOWP7xVxoq6JAUkJLJuZxpLpw5mR3k+Lv0unKQhEuoGDxdW8uPMoa3Yeo7DyJInxMVwxcQifO384c8emEB8b43eJ0o0pCEQilHOOde+X8PM3DrK78ENiDOaOTeWeReO4YtIQevfQ21fODf0kiUQY5xxvHyrn39d+wM4jJxg5sBf/fNVEPjNtKIOSE/0uT6KQgkAkgmzJq+DfX/2A7NwKhvVN5IGrp3DNBSPU9SOeUhCIRIDdhSf4j7UHeOtAKanJPfj+Zydxwywt+i5dQ0Eg4qP9x6v4z9cO8OreYvr1iudbV07g5tkZ9ExQAEjXURCI+CCntIaVrx/k5d3H6J0Qxz9cPo5b52Zo7n/xhYJApAvlltXyP28e4vntR0mIjeFLl4zmzvmj6Ncrwe/SJMAUBCIeO3biJH/cXcTLu4+xu/BDEuJiWD47g7sWjNbSjxIRFAQiHiitbuDPe4p4edcxtuRVAjBleF++vXgCS6YPZ3AfXQYqkUNBIHKOnKhr5NW9x3l5V9FfJ34bPziZb14xjqumDiND0z9LhFIQiHwCNQ3NvL6vmJd3HWP9wVKaWhwZA3vxlUvHcNXUYZr5U7oFBYHIWThQXM0jG3JYs/MYDc0hhvVN5JY5mXxm6jAmD++jid+kW1EQiHSQc46Nh8p4ZEMubx0oJTE+hmsuGMHV5w9nRnp/YmL04S/dk4JA5Awam0O8tOsYj2zIYf/xalJ69+D/LBrHFy4ayYAkXfYp3Z+CQOQUTtQ18nR2AavfyaOkuoHxg5P5ybVTWTJ9mKZ+kKjiaRCYWT/gEWAy4IBbgQ+AZ4EMIA+43jlX6WUdIp2RX17LoxtzeW5r
ISebWpg3NoUHr5vG/LEp6vuXqOT1GcHPgFecc9eaWQLQC/g2sM4594CZ3QfcB9zrcR0ip+WcY1t+JQ9vyGHtvmLiYowl04dz+7xMJgzp43d5Ip7yLAjMrA8wH1gB4JxrBBrNbAmwIPxlq4E3URCIT+qbWnhp1zFWv5PH3mNV9O0Zz5cXjGb57AwG6aYvCQgvzwhGAaXA42Y2DdgG3A0Mds4VATjnisxsUHsvNrM7gDsA0tPTPSxTgqiwso6nNhXw7JYCKuuaGDe4N/cvnczVM4bTK0FDZxIsXv7ExwEzgK8557LN7Ge0dgN1iHNuFbAKICsry3lTogSJc453D5ez+t08XttXDMAVE4dw88UjmT1qoPr/JbC8DIJCoNA5lx3e/h2tQVBsZkPDZwNDgRIPaxChtqGZ3+84ypPv5HGwpIb+veK585LR/P1FIxner6ff5Yn4zrMgcM4dN7MjZjbeOfcBsBDYF/6zHHgg/LjGqxok2HLLavnVu/k8t+0I1fXNTB7ehwevncpnpg0jMV6Xf4p8xOvO0K8BT4evGMoBbgFigN+a2W1AAXCdxzVIgDjn2JRTwcMbcnhjfwlxMcbiKUNZfvFIZqT3V/ePSDs8DQLn3E4gq52/Wujl95XgaQk5XtlznFXrD7Or8EMGJCXw9YVj+fsL03X1j8gZ6PII6dZONrbw3LYjPLIhl4KKOjIG9uL+pZO59oIR6v4R6SAFgXRL5TUNPPluPk++m0dlXRPT0/rx7cUTWDRxCLGa/E2kUxQE0q3kldXyyMYcnttaSENziMvPG8Qd80czM0P9/yJnS0Eg3cKOgkpWrc/hlb3HiY+J4XPnD+eL8zMZM0gLv4h8UgoCiVgVtY2s2XmU57cXsudoFX0S47jrktGsuFjTP4icSwoCiSiNzSH+8kEJv9tWyF/2l9Acckwa1ofvfWYi12al0buHfmRFzjW9q8R3zjn2HK3i+e2FrNl5lMq6JlKTe3Dr3EyunjFcs3+KeExBIL4pqarnhR2tXT8HimtIiIth0cTBXDtjBPPGphAXG+N3iSKBoCCQLtUScvx5TxHPbS1kw8FSQg5mpPfjXz83maumDKNvr3i/SxQJHAWBdAnnHG/sL+HHr+znQHENw/om8uUFY7h6xnBGpfb2uzyRQFMQiOe2F1TywJ/3szm3gsyUJH75+RlcOXkIMbrxSyQiKAjEM4dLa3jwlQ94Ze9xUnr34IdLJ3PDzDTi1fcvElE6FARmlgScdM6FzGwcMAH4s3OuydPqpFsqqapn5bqDPLvlCIlxMdyzaBy3zc0kSZd+ikSkjr4z1wPzzKw/sA7YCiwDvuBVYdL9VNU3seqtHB7ZmENLyHHTRSP56mVjSOndw+/SROQ0OhoE5pyrC68h8F/OuZ+Y2Q4vC5Puo6G5hac2FfCLNw5SWdfEZ6YN45tXjGPkwCS/SxORDuhwEJjZbFrPAG7r5GslSjU0t/DSzmP8bN1BCitPMmfMQO771HlMGdHX79JEpBM6+mH+DeBbwAvOub1mNgr4i2dVSUQrrqrn6U35PLO5gLKaRiYN68OPrp7CvLGpfpcmImehQ0HgnHsLeAvAzGKAMufc170sTCLPjoJKnngnjz/uLqLFORZOGMSKizOZM2agpoAW6cY6etXQM8CXgBZgG9DXzH7qnHvQy+LEf43NIf70XhGPv5PHriMnSO4Rx82zM1h+8UiNAYhEiY52DU10zlWZ2ReAPwH30hoICoIoVVrdwDPZBTyVnU9pdQOjUpL4wZJJXD1jhGYAFYkyHX1Hx5tZPLAU+IVzrsnMnHdliV/eK/yQx9/J5Q+7imhsCbFgfCorLs5g/thU3QksEqU6GgQPAXnALmC9mY0EqrwqSrretvwKfvraAd4+VE5SQiw3zkpj+cUZmgdIJAA6Olj8c+DnbXblm9ml3pQkXWnnkRP852sHeOtAKSm9E/jO4vNYNiuNPomaBVQkKDo6WNwX+BdgfnjXW8APgA89qks8tufo
h6x8/QCvv19C/17x3HflBG6ePZJeCer/Fwmajr7rHwP2ANeHt28CHgeu9qIo8c7+41WsfO0gr+w9Tp/EOL55xThWzMnUALBIgHX03T/aOXdNm+3vm9lOD+oRjxwqqWHl6wf443tF9E6I4+6FY7l1biZ9e6oLSCToOhoEJ81srnNuI4CZzQFOeleWnCt5ZbX8bN1B1uw8SmJ8LF9eMJovzhtFv14JfpcmIhGio0HwJeDJ8FgBQCWw3JuS5Fw4XFrD/33zML/fcZT4WOP2eaO4c/4oBmomUBH5mI5eNbQLmGZmfcLbVWb2DWC3h7XJWdheUMlDbx1m7b5i4mNjuHn2SO5aMJpByYl+lyYiEapTI4TOubb3DtwDrDyn1chZCYVa1wNetT6HzXkV9O0Zz1cvHcPNszNITdYZgIic3ie5VES3mfqsobmFNTuPsWp9DodKahjeryffvWoiy2amaTUwEemwT/JpoSkmfFJV38Svswt47O1ciqsamDAkmZXLpvPpqUO1HrCIdNppg8DMqmn/A9+Anp5UJKdUXFXPY2/n8symAqobmpkzZiAPXjuNeWNTNA20iJy10waBcy65qwqRU/vwZBMPvrqfZ7ccoSXkWDxlKHfOH62VwETknFBHcoR7fV8x33nxPcpqGrlxVhp3zBtN+sBefpclIlFEQRChymsa+P7L+3hp1zEmDEnmkZtn6gxARDzhaRCYWR5QTevKZs3OuSwzGwA8C2TQOrX19c65Si/r6E6cc7y8u4jvvbSX6vom7lk0ji9dMpqEOA0Ci4g3uuKM4FLnXFmb7fuAdc65B8zsvvD2vV1QR8QrrqrnOy/s4fX3i5mW1o+fXDOV8UM0TCMi3vKja2gJsCD8fDXwJgEPAuccv916hPv/+D6NzSG+s/g8bp2bSaxWBBORLuB1EDhgbXhZy4ecc6uAwc65IgDnXJGZDWrvhWZ2B3AHQHp6usdl+udIRR3f+v17bDxUxoWZA/jxNVPJSNGi8CLSdbwOgjnOuWPhD/vXzGx/R18YDo1VAFlZWVF381oo5Hjy3Tx+8uoHxJhx/9LJfH5WutYFFpEu52kQOOeOhR9LzOwFYBZQbGZDw2cDQ4ESL2uIRHlltXzzuV1sza9kwfhU/u1zUxjWT/fniYg/PLsUxcySzCz5o+fAFbSucvYSf5vCejmwxqsaItHbh8r47C82crCkhp9eP43HV8xUCIiIr7w8IxgMvBCe+iAOeMY594qZbQF+a2a3AQXAdR7WEFGezs7nu2v2Mjo1iUeXzyRtgG4MExH/eRYEzrkcYFo7+8uBhV5930jU3BLi/j++zxPv5HHp+FR+fuP5JCdqiUgRiQy6s9hjVfVNfPWZHaw/UMptczP59uLzdFmoiEQUBYGH8struW31VvLKavnR1VO4cVb0XgYrIt2XgsAjm3LK+dJT2wD41W0XMnv0QJ8rEhFpn4LAA7/dcoTvvPge6QN68ejymbpBTEQimoLgHGoJOX78yn5Wrc9h3tgUfvH5GfTtqUFhEYlsCoJzpKahmbt/vYN1+0tYPnsk/3zVROK0bKSIdAMKgnPgSEUdt6/eyqHSGn64ZBI3zc7wuyQRkQ5TEHxCB4qruXHVJhpbQjxxy0zmjU31uyQRkU5REHwCxVX1rHhsMzExxgt3zmHMoN5+lyQi0mnqxD5L1fVNrHh8Cx+ebOLxFTMVAiLSbemM4Cw0Noe466ntHCyu5rEVM5k8XGsJi0j3pSDoJOcc9/1+NxsPlfHgtVOZP05jAiLSvalrqJN++toBfr/9KPcsGsd1WWl+lyMi8okpCDrhmewC/uuNQ9wwM42vXTbG73JERM4JBUEHvbG/mH9es4dLx6dy/9LJhNdZEBHp9hQEHbC78ARfeXoHE4f24Refn6E7hkUkqugT7QwKyuu49YktpCQn8NiKmST10Pi6iEQXBcFpVNQ2svzxzTSHHE/cMovU5B5+lyQics7p19tTqG9q4fbVWzh64iTP3H4ho1N1
w5iIRCedEbSjJeS4+zc72HHkBD9bNp2sjAF+lyQi4hkFwcc45/jBy3t5dW8x371qIldOGep3SSIinlIQfMyjG3NZ/W4+X5yXyS1zMv0uR0TEcwqCNppaQvzH2gMsnDCIb115nt/liIh0CQVBG7sLP+RkUwvXXjCCmBjdMCYiwaAgaCM7txyAmZkaHBaR4FAQtJGdU8GYQb1J6a37BUQkOBQEYc0tIbbmVXChzgZEJGAUBGF7j1VR29jChaMG+l2KiEiXUhCEfTQ+cJHOCEQkYBQEYZtzK8hMSWJQn0S/SxER6VIKAlqnlNicq/EBEQkmBQGw/3gVVfXNXDhKQSAiwaMgoPWyUYALMzVQLCLBoyCgdaB4RP+eDOvX0+9SRES6XOCDIPTX8QGdDYhIMAU+CA6W1FBZ16TxAREJLM+DwMxizWyHmf0hvD3AzF4zs4Phx/5e13A6f7t/QGcEIhJMXXFGcDfwfpvt+4B1zrmxwLrwtm+ycysY2jeRtAEaHxCRYPI0CMxsBPBp4JE2u5cAq8PPVwNLvazhdJxzZOe03j9gpmmnRSSYvD4jWAn8ExBqs2+wc64IIPw4qL0XmtkdZrbVzLaWlpZ6UlxOWS1lNQ2aX0hEAs2zIDCzq4AS59y2s3m9c26Vcy7LOZeVmpp6jqtr9bf7BzRQLCLBFefhvz0H+KyZLQYSgT5m9hRQbGZDnXNFZjYUKPGwhtPKzi0nNbkHmSlJfpUgIuI7z84InHPfcs6NcM5lADcAbzjn/h54CVge/rLlwBqvajhDfWTnVDBL4wMiEnB+3EfwALDIzA4Ci8LbXa6goo7jVfWadlpEAs/LrqG/cs69CbwZfl4OLOyK73s6fx0f0ECxiARcYO8s3pRbzoCkBMYO6u13KSIivgpsEGzOrWBWhsYHREQCGQRHT5yksPKk5hcSESGgQZCd0zq/kGYcFREJbBBU0CcxjglDkv0uRUTEd8EMgtxyZmUOICZG4wMiIoELguKqevLK69QtJCISFrgg2PTR+IAGikVEgAAGQXZuBb17xDFxaB+/SxERiQjBC4KccrIy+hMXG7imi4i0K1CfhmU1DRwurdX4gIhIG4EKgs25H80vpPEBEZGPBCoIsnPK6ZUQy5Thff0uRUQkYgQrCHIruGBkf+I1PiAi8leB+USsrG1k//FqZmWoW0hEpK3ABMHmPK0/ICLSnsAEQXZOBT3iYpiWpvEBEZG2ghMEueWcn96PHnGxfpciIhJRAhEEVfVN7Cuq0v0DIiLtCEQQbM2rwDndPyAi0p5ABEF2TgUJsTHMSO/vdykiIhEnEEGwKbeCaWl9SYzX+ICIyMdFfRDUNDSz5+iHzMpUt5CISHuiPgi25VfSEnIaKBYROYWoD4LsnHJiY4wLRmp8QESkPdEfBLkVTBnel6QecX6XIiISkaI6CE42trC78IQuGxUROY2oDoIdBZU0tTgu0viAiMgpRXUQbMqtIMYgK0PjAyIipxLVQTC8XyLXXjCC5MR4v0sREYlYUT2CumxmOstmpvtdhohIRIvqMwIRETkzBYGISMApCEREAk5BICIScAoCEZGAUxCIiAScgkBEJOAUBCIiAWfOOb9rOCMzKwXyz/LlKUDZOSynuwly+9X24Apy+9u2faRzLvVML+gWQfBJmNlW51yW33X4JcjtV9uD2XYIdvvPpu3qGhIRCTgFgYhIwAUhCFb5XYDPgtx+tT24gtz+Trc96scIRETk9IJwRiAiIqehIBARCbioDgIz+5SZfWBmh8zsPr/r6Upmlmdm75nZTjPb6nc9XjOzx8ysxMz2tNk3wMxeM7OD4ceoXLP0FG3/npkdDR//nWa22M8avWJmaWb2FzN738z2mtnd4f1BOfanan+njn/UjhGYWSxwAFgEFAJbgBudc/t8LayLmFkekOWcC8RNNWY2H6gBnnTOTQ7v+wlQ4Zx7IPyLQH/n3L1+1umFU7T9e0CNc+7f/azNa2Y2FBjqnNtuZsnANmAp
sIJgHPtTtf96OnH8o/mMYBZwyDmX45xrBH4DLPG5JvGIc249UPGx3UuA1eHnq2l9g0SdU7Q9EJxzRc657eHn1cD7wHCCc+xP1f5OieYgGA4cabNdyFn8B3VjDlhrZtvM7A6/i/HJYOdcEbS+YYBBPtfT1b5qZrvDXUdR2TXSlpllAOcD2QTw2H+s/dCJ4x/NQWDt7IvOfrD2zXHOzQCuBL4S7j6Q4PgfYDQwHSgC/sPXajxmZr2B54FvOOeq/K6nq7XT/k4d/2gOgkIgrc32COCYT7V0OefcsfBjCfACrV1lQVMc7kP9qC+1xOd6uoxzrtg51+KcCwEPE8XH38ziaf0QfNo59/vw7sAc+/ba39njH81BsAUYa2aZZpYA3AC85HNNXcLMksIDR5hZEnAFsOf0r4pKLwHLw8+XA2t8rKVLffQhGPY5ovT4m5kBjwLvO+d+2uavAnHsT9X+zh7/qL1qCCB8ydRKIBZ4zDn3r/5W1DXMbBStZwEAccAz0d52M/s1sIDWKXiLgX8BXgR+C6QDBcB1zrmoG1Q9RdsX0Not4IA84M6P+syjiZnNBTYA7wGh8O5v09pPHoRjf6r230gnjn9UB4GIiJxZNHcNiYhIBygIREQCTkEgIhJwCgIRkYBTEIiIBJyCQAQws5Y2MzXuPJez1ZpZRtuZQUUiTZzfBYhEiJPOuel+FyHiB50RiJxGeF2HH5vZ5vCfMeH9I81sXXhSr3Vmlh7eP9jMXjCzXeE/F4f/qVgzezg8Z/xaM+vpW6NEPkZBINKq58e6hpa1+bsq59ws4Be03qlO+PmTzrmpwNPAz8P7fw685ZybBswA9ob3jwV+6ZybBJwArvG0NSKdoDuLRQAzq3HO9W5nfx5wmXMuJzy513Hn3EAzK6N1QZCm8P4i51yKmZUCI5xzDW3+jQzgNefc2PD2vUC8c+7+LmiayBnpjEDkzNwpnp/qa9rT0OZ5CxqfkwiiIBA5s2VtHt8NP3+H1hltAb4AbAw/XwfcBa3LpZpZn64qUuRs6bcSkVY9zWxnm+1XnHMfXULaw8yyaf3F6cbwvq8Dj5nZPwKlwC3h/XcDq8zsNlp/87+L1oVBRCKWxghETiM8RpDlnCvzuxYRr6hrSEQk4HRGICIScDojEBEJOAWBiEjAKQhERAJOQSAiEnAKAhGRgPtfjhPLr7ZsE6kAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training-accuracy history as recorded in loss_hist[\"train accuracy\"];\n",
    "# the x-axis assumes one entry per epoch.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(loss_hist[\"train accuracy\"])\n",
    "ax.set(xlabel=\"Epoch\", ylabel=\"Accuracy\", title=\"Training accuracy\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY8AAAEGCAYAAACdJRn3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAldUlEQVR4nO3deXxUVZ738c8vCwk7JGHLQhJCUFBZA0IAFXdtFZdWcUWlpcdBxbYX2/H1PDPPPOPz2O30uHS3rbQiuICija1tixsuCAFDWGRT1oQkECDsSxaynPkjRZvBAKmQqptUfd+vF69UTt2q+l2v8K1zzr3nmnMOERERf0R4XYCIiLQ+Cg8REfGbwkNERPym8BAREb8pPERExG9RXhcQKAkJCS4tLc3rMkREWpVly5btds51O9V2IRseaWlp5OXleV2GiEirYmZbG7Odhq1ERMRvCg8REfGbwkNERPym8BAREb8pPERExG8KDxER8ZvCQ0RE/KbwOM6riwv42zfbvS5DRKRFC9mLBJvqrWXFxEZFcvWgRK9LERFpsdTzOM6ojHhWFO2j/GiN16WIiLRYCo/jZGckUFXjyNu61+tSRERaLIXHcYandSUqwsjZvMfrUkREWiyFx3HatYliSO8uCg8RkZNQeDRgVEYCq4v3c7CiyutSRERapICFh5lNN7NdZrbmuPYHzGy9ma01s9/Wa3/UzDb5nrusXvswM1vte+5ZM7NA1XxMdkY8tQ5yt2jeQ0SkIYHsecwALq/fYGbjgPHAQOfcWcB/+toHABOAs3yvec7MIn0v+xMwGcj0/fkf7xkIQ3p3ISYqQkNXIiInELDwcM4tAI7/6n4f8IRzrtK3zS5f+3jgDedcpXMuH9gEjDCzXkAn59xi55wDXgGuDVTNx8RERTI8LY6czbsD/VEiIq1SsOc8+gFjzexrM/vSzIb72pOAonrbFfvaknyPj28PuFEZ8Xy34xB7DlcG4+NERFqVYIdHFNAVGAn8Epjjm8NoaB7DnaS9QWY22czyzCyvtLT0tAodlREPwBLNe4iI/ECww6MYmOvq5AK1QIKvPaXedsnAdl97cgPtDXLOTXPOZTnnsrp1O+X9209qYFJnOsREaehKRKQBwQ6PvwIXAphZP6ANsBt4D5hgZjFmlk7dxHiuc64EOGRmI309lDuBd4NRaFRkBCPS41isSXMRkR8I5Km6s4HFwBlmVmxmk4DpQB/f6btvABN9vZC1wBxgHfAhMMU5d2xxqfuAF6mbRN8MzAtUzcfLzohny+4jlBwoD9ZHioi0CgFbVdc5d8sJnrr9BNs/DjzeQHsecHYzltZox+Y9Fm/ew/VDk0+xtYhI+NAV5ifRv2cnurSL1tCViMhxFB4nERFhjOoTT87mPdRdZiIiIqDwOKXsjHi27S+naK/mPUREjlF4nMKojAQAnbIrIlKPwuMUMrq1p3vHGK1zJSJSj8LjFMyM7AzNe4iI1KfwaITsjAR2H65k067DXpciItIiKDwa4dj1Hhq6EhGpo/BohJS4dqTEtdWkuYiIj8KjkbL7JLBky15qajXvISKi8Gik7L7xHCiv4tuSg16XIiLiOYVHI43qc2zeQ0NXIiIKj0bq3imWjG7tNWkuIoLCwy/ZGQnk5u+lqqbW61JERDyl8PBDdkY8ZUdrWFV8wOtSREQ8pfDww8g+x+7voXkPEQlvCg8/dG3fhgG9OmneQ0TCnsLDT9kZ8eRt3UdFVc2pNxYRCVEKDz9l943naHUtywv3eV2KiIhnFB5+Gp4WR2SE6da0IhLWFB5+6hgbzcDkzpr3EJGwpvBoguyMeL4p2s/hymqvSxER8YTCowmyMxKornUsLdjrdSkiIp5QeDTBsNSutImM0LyHiIQthUcTxEZHMjS1ixZJFJGwpfBoouyMBNZuP8j+sqNelyIiEnQKjybKzojHOViyRfMeIhJ+FB5NNDC5C+3aRGqdKxEJSwqPJmoTFUFW
WhyLt2jSXETCj8LjNGRnxLNh52FKD1V6XYqISFApPE5DdoZviXb1PkQkzCg8TsNZiZ3pGBuleQ8RCTsKj9MQGWGM7BOvda5EJOwoPE5TdkY8W/eUUbyvzOtSRESCJmDhYWbTzWyXma1p4LlfmJkzs4R6bY+a2SYzW29ml9VrH2Zmq33PPWtmFqiam2JsZt0u/OGzTTjnPK5GRCQ4AtnzmAFcfnyjmaUAlwCF9doGABOAs3yvec7MIn1P/wmYDGT6/vzgPb3Ut3tH7h/XlzeWFjFtwRavyxERCYqAhYdzbgHQ0OXXTwG/Aup/TR8PvOGcq3TO5QObgBFm1gvo5Jxb7Oq+1r8CXBuompvq4Uv6cfWgRP7/vO/4YHWJ1+WIiARcUOc8zOwaYJtz7pvjnkoCiur9XuxrS/I9Pr79RO8/2czyzCyvtLS0mao+tYgI48kfD2RYald+9uZK3aJWREJe0MLDzNoBjwH/u6GnG2hzJ2lvkHNumnMuyzmX1a1bt6YV2kSx0ZFMu2MYPTvHcu/MPIr2agJdREJXMHseGUA68I2ZFQDJwHIz60ldjyKl3rbJwHZfe3ID7S1SfIcYpt81nOpax10v53KgrMrrkkREAiJo4eGcW+2c6+6cS3POpVEXDEOdczuA94AJZhZjZunUTYznOudKgENmNtJ3ltWdwLvBqrkpMrp14IU7hlG4t4z7Xl/G0epar0sSEWl2gTxVdzawGDjDzIrNbNKJtnXOrQXmAOuAD4Epzrka39P3AS9SN4m+GZgXqJqby8g+8fzmhoHkbN7DY++s1im8IhJyogL1xs65W07xfNpxvz8OPN7AdnnA2c1aXBBcPzSZrXvKeGb+RtIS2jNlXF+vSxIRaTYBCw+Bhy7OpHBvGU9+tJ6UuHZcMyjR65JERJqFwiOAzIwnbjiHbfvK+cVb35DYOZastDivyxIROW1a2yrAYqIieeGOYSR1acu9r+RRsPuI1yWJiJw2hUcQdG3fhpfvGg7APTOWsr/sqMcViYicHoVHkKQltGfanVkU7ytn8qvLqKyuOfWLRERaKIVHEA1Pi+PJGweSm7+XR95eRXWNrgERkdZJE+ZBNn5wEsX7ynnyo/XsOFjBH24dSkKHGK/LEhHxi3oeHpgyri+/u3EQKwr3c/XvF7JCCymKSCuj8PDIDcOS+ct92URGGDe/sITXv96qK9FFpNVQeHjo7KTOvP/AGEZlxPPYO2v41durqKjSRLqItHwKD491adeG6XcN58EL+/LWsmJ+/HyO7ocuIi2ewqMFiIwwHr70DP58ZxZbd5dx9e8X8tXG4N3MSkTEXwqPFuSSAT1474ExdO8Yy8Tpufzx802aBxGRFknh0cKkJ7TnnSnZ/GhgIk9+tJ6fvrqMQxW6qZSItCwKjxaoXZsonp0wmP911QDmf7eL8X9YxMadh7wuS0TkHxQeLZSZMWlMOq//5FwOVlQx/o+LmLu8WMNYItIiKDxauJF94nn/gbH079WJh+d8w+0vfc3m0sNelyUiYU7h0Qr07BzLnJ+O4t/Hn8WqogNc8fRX/O7j9bomREQ8o/BoJSIjjDtHpTH/F+dz5Tk9+f1nm7j0qQV8vn6X16WJSBhSeLQy3TvG8vSEIcz6yblERRp3v7yU+15bRsmBcq9LE5EwovBopbL7JjBv6lh+edkZfPbdLi763Zf8ecEWqrTMu4gEgcKjFYuJimTKuL58+vD5jOwTz+MffMvVv19IXsFer0sTkRCn8AgBKXHteGliFi/cMYyD5VX8+PnFPPL2KvYe0e1uRSQwFB4hwsy47KyefPLw+fz0vD78ZXkxF/3uC15amK+zskSk2Sk8Qkz7mCgevbI/f39wLAMSO/F/31/H+U9+zquLC3TfdBFpNhaqVyxnZWW5vLw8r8vw3OLNe/ivT9aztGAfSV3acv+FffnxsGSiI/W9QUR+yMyWOeeyTrmdwiP0OedYuGk3v/t4AyuL9pMS15YHL8zkuiFJRClERKSe
xoaH/uUIA2bG2MxuvPPP2Uy/K4vObaP55duruPSpBby7chs1taH5BUJEAkfhEUbMjAvP7MHf7h/DC3cMo01UBFPfWMnlTy/g76tKqFWIiEgjKTzC0LEzsz54cCx/uHUIDpgyazlXPvsVH67ZoRARkVPSnIdQU+v42zfbeWb+RvJ3HyGjW3vuHduHa4ckERsd6XV5IhJEmjBXePituqaWv68uYdqCLazdfpCEDjHclZ3K7SNT6dKujdfliUgQKDwUHk3mnCNn8x6mLdjClxtKaRsdyc3DU5g0Jp2UuHZelyciAdSsZ1uZWXszi/A97mdm15hZ9CleM93MdpnZmnptT5rZd2a2yszeMbMu9Z571Mw2mdl6M7usXvswM1vte+5ZM7PG1CxNZ2aM7pvAzHtG8OFDY7ninJ68tmQr5z/5OVNmLWdV8X6vSxQRjzWq52Fmy4CxQFdgCZAHlDnnbjvJa84DDgOvOOfO9rVdCnzmnKs2s98AOOceMbMBwGxgBJAIfAr0c87VmFkuMNX3uR8Azzrn5p2qZvU8mlfJgXJmLCpg1teFHKqs5tz0OH56fh8u6NediAjluUioaO7rPMw5VwZcD/zeOXcdMOBkL3DOLQD2Htf2sXOu2vfrEiDZ93g88IZzrtI5lw9sAkaYWS+gk3NusatLuVeAaxtZszSjXp3b8uiV/cl59EIeu7I/hXvLuGdGHpc9vYA5eUUcrdZS8CLhpNHhYWajgNuAv/vaok7zs+8BjvUgkoCies8V+9qSfI+PbxePdIyN5t7z+rDgV+N46uZBREYYv3p7Fec/+TnTF+ZTdrT61G8iIq1eY8PjIeBR4B3n3Foz6wN83tQPNbPHgGrg9WNNDWzmTtJ+ovedbGZ5ZpZXWlra1PKkEaIjI7huSDLzpo7l5buHk9K1Hf/+/jpGP/EZz87fyIGyKq9LFJEAalTvwTn3JfAlgG/ifLdz7sGmfKCZTQSuAi5y30+4FAMp9TZLBrb72pMbaD9RndOAaVA359GU+sQ/Zsa4M7oz7ozu5BXs5bkvNvNfn2zghS83c9vIVH4yJp3unWK9LlNEmlljz7aaZWadzKw9sA5Yb2a/9PfDzOxy4BHgGt8cyjHvARPMLMbM0oFMINc5VwIcMrORvrOs7gTe9fdzJTiy0uKYftdw5k0dy0X9e/DiV1sY85vPeXTuarbuOeJ1eSLSjBp7ttVK59xgM7sNGEZdACxzzg08yWtmAxcACcBO4F+pG/qKAfb4NlvinPsn3/aPUTcPUg08dOyMKjPLAmYAbambI3nANaJonW3lvYLdR3hhwRb+sqyY6tpafjQwkX++IIP+vTp5XZqInECzXiRoZmuBwcAs4A/OuS/N7Bvn3KDTrjRAFB4tx86DFby0MJ/Xl2zlyNEaLjyzO1PG9WVYalevSxOR4zT3qbovAAVAe2CBmaUCB5tenoSTHp1i+Zcr+5Pz64t4+JJ+rCjcxw1/yuGOl75macHeU7+BiLQ4TV6exMyi6l2z0eKo59FyHams5rUlW5m2YAt7jhxlVJ94pl6cycg+8V6XJhL2mnvYqjN1cxbn+Zq+BP7dOXfgtKoMIIVHy1d+tIbXv97KCwu2UHqokhHpcTx0USajMuLRKjQi3mju8PgLsAaY6Wu6AxjknLv+tKoMIIVH61FRVcPs3EKe/3IzOw9WkpXalQcvymRsZoJCRCTImjs8VjrnBp+qrSVReLQ+FVU1vJVXxHNfbKbkQAVDenfhwYsyuaBfN4WISJA094R5uZmNqffmo4HyphYn0pDY6EjuGJXGF7+8gMevO5tdByu5++WljP/jIj5Zt1N3OBRpQRrb8xhE3aKEnX1N+4CJzrlVAazttKjn0fodra5l7vJi/vjFJor2ltOnW3vuGZ3ODUOTadtGdzgUCYSA3AzKzDoBOOcOmtlDzrmnm15iYCk8QkdVTS1/X1XCSwvzWb3tAF3aRXPriN5MzE6jh5Y+EWlWAb+ToJkVOud6N+nFQaDw
CD3OOZYW7OOlhVv4eN1OoiKMqwYmMmlMOmcndT71G4jIKTU2PE5nWXXNYEpQmRkj0uMYkR5H4Z4yXs7JZ87SIt5ZsY0R6XFMGpPOxf17EKmbU4kEnHoe0qodrKjizdwiZuQUsG1/Oanx7bg7O40bs1JoH3O6t5wRCT/NMmxlZodo+P4ZBrR1zrXYv50Kj/BSXVPLR2t38tLCLSwv3E/H2Kh/zIskdmnrdXkirUbA5zxaOoVH+FpeuI+XFubz4ZodAPzonF78ZGw6A5O7eFuYSCsQjDkPkRZpaO+uDL21K8X7ypiZU8AbuUW89812RqTFMWms5kVEmoN6HhLyDlVUMSevmOkL8/8xL3LP6HR+PCxZ8yIix9GwlcJDjnNsXuTFhVtYUbifTrFR3HpuKhOzU+nVWfMiIqDwUHjISS3bWne9yIdrdhBhxtWDdL2ICGjOQ+SkhqV2ZVjqMIr2lvHyogLeXFr4j+tF7s5O45IBPYiKbOzSbyLhRz0PEb6/XmTm4gKK95WT2DmWO0alMWF4Cl3bt/G6PJGg0bCVwkOaoKbWMf/bncxcXMCiTXuIiYrg2sFJ3DU6jf69OnldnkjAKTwUHnKa1u84xMzFBcxdXkxFVS3npsdx9+g0Lu6vIS0JXQoPhYc0k/1lR5mTV8TMnK1s219OUpe23DEqlQnDU+jSTkNaEloUHgoPaWY1tY5Pv93JjEUFLN6yh9jouiGtu0enc0bPjl6XJ9IsFB4KDwmg73YcZGZOAe+s2EZFVS1jMxO4Z0w652d2I0JXr0srpvBQeEgQ7DtylFm5hbyyuICdByvJ6Naeu0enc/3QJNq10Znw0vooPBQeEkRHq2uZt6buboerig/QuW00t57bm4mj0ujZWXc7lNZD4aHwEA8458jbuo/pC/P5aG3d1es/GtiLe0anMyili9fliZySrjAX8YCZMTwtjuFpcRTtrVvV982lRby7cjtZqV2ZNCZdV69LSFDPQyTADlVU8VZeMTNyCijcW0bPTrFcOySJ64cm0a+HztKSlkXDVgoPaWGOneo7Z2kRX2wopabWcXZSJ64bksw1gxLp1jHG6xJFFB4KD2nJdh+u5G/fbGfu8m2s3naAyAjjvMwErhuazKUDehAbHel1iRKmFB4KD2klNu06xNzl2/jrim1sP1BBh5gorjynJ9cNSebc9DhdNyJBpfBQeEgrU1vrWJK/h7nLtzFvdQlHjtaQ1KUt4wcncvPwFFLj23tdooQBhYfCQ1qx8qM1fLxuB3OXb+OrjaU44IJ+3bgzO01XsUtAeR4eZjYduArY5Zw729cWB7wJpAEFwE3OuX2+5x4FJgE1wIPOuY987cOAGUBb4ANgqmtE0QoPCRU7D1Yw6+tCZuUWUnqoktT4dtwxMpUbh6XQuV201+VJiGkJ4XEecBh4pV54/BbY65x7wsx+DXR1zj1iZgOA2cAIIBH4FOjnnKsxs1xgKrCEuvB41jk371Sfr/CQUHO0upYP1+7glZwC8rbuIzY6guuGJHHHyDQGJOpeI9I8PL9I0Dm3wMzSjmseD1zgezwT+AJ4xNf+hnOuEsg3s03ACDMrADo55xYDmNkrwLXAKcNDJNS0iYrgmkGJXDMokbXbD/Dq4q28s2Ibs3OLGJEWx53ZqVx2Vk+idQGiBEGw/y/r4ZwrAfD97O5rTwKK6m1X7GtL8j0+vr1BZjbZzPLMLK+0tLRZCxdpSc5K7MwTNwxkyaMX8diV/Sk5WM79s1Yw+onPePrTDew6WOF1iRLiWspXlIZm/9xJ2hvknJvmnMtyzmV169at2YoTaam6tGvDvef14YtfjGP6XVn079WJpz/dyOjffMZj76ym5EC51yVKiAr22lY7zayXc67EzHoBu3ztxUBKve2Sge2+9uQG2kWknsgI48Ize3DhmT3I332EP3+1hTl5RbyVV8yt5/bmny/IoHsnre4rzSfYPY/3gIm+xxOBd+u1TzCzGDNLBzKBXN/Q1iEzG2lmBtxZ
7zUi0oD0hPb8v+vO4bOfX8B1Q5J4dclWxv72c/7j/XXsPlzpdXkSIgJ5ttVs6ibHE4CdwL8CfwXmAL2BQuBG59xe3/aPAfcA1cBDx86oMrMsvj9Vdx7wgE7VFWm8rXuO8Mz8jfx1xTZioyOZmJ3G5LF96Npe91+XH/L8VF2vKTxE/qfNpYd55tON/G3Vdtq3ieKe0WlMGtuHzm11rYh8T+Gh8BBp0Podh3hm/gY+WL2DjrFR3Du2D3ePTqNjrEJEFB4KD5FTWLv9AE9/upFP1u2kS7tofjImnVvPTSVOw1lhTeGh8BBplFXF+3nqkw18vr6UNlERXDs4kYnZaZyV2Nnr0sQDCg+Fh4hfNuw8xMycAuYu30Z5VQ0j0uO4OztNt80NMwoPhYdIkxwoq2JOXhEzFxdQvK+cxM6x3D4qlQnDe2tIKwwoPBQeIqelptYx/9udzMgpIGfzHmKiIrh2cBITs7UQYyhTeCg8RJrN+h2HmLm4gLnLi6moqtWQVghTeCg8RJrd/rKjdUNaOVvZtr+cHp1iuDkrhZuGp5DctZ3X5UkzUHgoPEQC5tiQ1uzcQr7YULeC9fn9unHLiN5ceGZ3LQvfiik8FB4iQVG8r4w5ecXMWVrEjoMVdO8Yw01ZKdw8PIWUOPVGWhuFh8JDJKiqa2r5Yn0ps3ML+Xz9LhwwNrMbt45I4aL+PdQbaSUUHgoPEc9s31/OnLwi3lxaRMmBChI6xHBTVjIThvemd7x6Iy2ZwkPhIeK5mlrHgg2lzMot5LPvdlHrHNcOTuLhS/ppSKuF8vwe5iIikRHGuDO7M+7M7uw4UMHLOfnMWFTA31eVcOeoVKaM66ul4Vsp9TxEJKhKDpTz1CcbeHtZMe1jorjvggzuzk6nbZtIr0sTNGyl8BBp4dbvOMSTH33Hp9/uomenWH52SSY3DE3WRYcea2x46CiJiCfO6NmRFycO583JI+nVJZZH/rKaK575ik/W7SRUv9SGEoWHiHjq3D7xzL0vm+dvH0pNrePeV/K46YXFLNu6z+vS5CQUHiLiOTPj8rN78dHPzuPx686mYE8ZN/wph5++msemXYe8Lk8aoDkPEWlxyo5W89JX+bywYAuHK6sZ1SeeW87tzWVn9SAmShPrgaQJc4WHSKu353AlbywtYnZuIcX7yunaLpobhiYzYURv+nbv4HV5IUnhofAQCRm1tY6Fm3YzO7eQT9btpLrWMSI9jltGpHDF2b2IjVZvpLkoPBQeIiGp9FAlby8r5o2lhWzdU0bnttFcPzSJW0b0pl+Pjl6X1+opPBQeIiGtttaxZMseZuUW8tHaHVTVOIalduWWEb25aqB6I02l8FB4iISNPYcrmbt8G7NzC9my+whd20Vz+8hU7hiVSveOsV6X16ooPBQeImHHOceSLXuZviifT7/dSXREBOMHJzJpbDpn9tR91xtDCyOKSNgxM0ZlxDMqI5783Ud4eVE+b+UV89ayYsZmJjBpTDrn9+uGmXldaqunnoeIhLT9ZUeZlVvIzJwCdh6sJLN7ByaNSefaIUmaF2mAhq0UHiJSz9HqWt5ftZ0Xv8pnXclB4tu34Y5Rqdw+MpWEDjFel9diKDwUHiLSAOcci7fs4aWv8pn/3S7aREVw/ZAkJp/Xhz7ddOGh5jxERBpgZmRnJJCdkcCmXYd5eVE+by8rZk5eEVee04sp4/rSv5cm109FPQ8RCXulhyp5aWE+ry3ZyuHKai7u34P7L+zL4JQuXpcWdBq2UniIiJ8OlFUxI6eA6YvyOVBexdjMBKaM68u56XFhc4ZWi74ZlJn9zMzWmtkaM5ttZrFmFmdmn5jZRt/PrvW2f9TMNpnZejO7zIuaRST0dW4XzdSLM1n06wt59Ioz+bbkEBOmLeHG5xfzxfpduklVPUHveZhZErAQGOCcKzezOcAHwABgr3PuCTP7NdDVOfeImQ0AZgMjgETgU6Cfc67mZJ+jnoeInK6Kqhrm5BXx/Beb2X6ggnOSOjNl
XF8uHdCDiIjQ7Im06J4HdRP1bc0sCmgHbAfGAzN9z88ErvU9Hg+84ZyrdM7lA5uoCxIRkYCKjY7kzlFpfPHLcfz2hoEcqqjin15bxuXPLODdlduoqQ3fnkjQw8M5tw34T6AQKAEOOOc+Bno450p825QA3X0vSQKK6r1Fsa/tB8xsspnlmVleaWlpoHZBRMJMm6gIbhqewvyfX8CztwzBMKa+sZJLnvoybEMk6OHhm8sYD6RTNwzV3sxuP9lLGmhr8Eg556Y557Kcc1ndunU7/WJFROqJjDCuGZTIvKljef72obSJjGDqGyu5NAxDxIthq4uBfOdcqXOuCpgLZAM7zawXgO/nLt/2xUBKvdcnUzfMJSLiiYiIunuuf/DgWP5021CiIsIvRLwIj0JgpJm1s7pz3y4CvgXeAyb6tpkIvOt7/B4wwcxizCwdyARyg1yziMgPREQYV5zTi3lTx/JcvRC57OkFvPfN9pAOEU+u8zCz/wPcDFQDK4CfAB2AOUBv6gLmRufcXt/2jwH3+LZ/yDk371SfobOtRCTYamsd89bs4Jn5G9iw8zCZ3Tvw4EWZ/OicXq3m7CxdJKjwEBGP1NY6PlhTwjOfbmTjrroQmXpxJlee3fJDpKWfqisiErIiIoyrBiby4UPn8ftbhuCA+2et4IpnvuLjtTtC4mJDhYeISIBERhhXD0rko4fO45kJgzlaU8vkV5dx7XM5LNq02+vyTovCQ0QkwCIjjPGDk/jkZ+fxmxvOofRgBbe9+DW3/nkJywv3eV1ek2jOQ0QkyCqqapj1dSF//HwTe44c5eL+Pfj5pf1axFLwmjBXeIhIC3ekspoZOQU8/+VmDldWc/XARH52ST/SE9p7VpPCQ+EhIq3EgbIqXliwmZcXFXC0ppabspJ54MJMEru0DXotCg+Fh4i0MrsOVfDc55t5/eutmBm3n5vK5PP60LNzbNBqUHgoPESklSreV8az8zfy9rJiIsy4amAv7hmTzsDkLgH/bIWHwkNEWrnCPWXMyClgTl4RhyuryUrtyqQx6Vx6Vk8iA3SxocJD4SEiIeJQRRVz8oqZkZNP0d5ykru25a7sNG4ankKn2Ohm/SyFh8JDREJMTa3jk3U7mb4on9z8vbRvE8mNWSncPTqN1PjmOUNL4aHwEJEQtrr4ANMX5fP+qu1U1zou7t+DSWPSOTc9jroFy5tG4aHwEJEwsPNgBa8t2cprS7ayr6yKAb06MePu4XTv1LQztBobHlFNencREWkRenSK5eeXnsGUcX3564ptfPbdLhI6xAT8cxUeIiIhIDY6kgkjejNhRO+gfJ4WRhQREb8pPERExG8KDxER8ZvCQ0RE/KbwEBERvyk8RETEbwoPERHxm8JDRET8FrLLk5hZKbC1iS9PAHY3YzmtSTjvO4T3/ofzvkN473/9fU91znU71QtCNjxOh5nlNWZtl1AUzvsO4b3/4bzvEN7735R917CViIj4TeEhIiJ+U3g0bJrXBXgonPcdwnv/w3nfIbz33+9915yHiIj4TT0PERHxm8JDRET8pvCox8wuN7P1ZrbJzH7tdT3BZmYFZrbazFaaWUjfw9fMppvZLjNbU68tzsw+MbONvp9dvawxkE6w//9mZtt8x3+lmV3pZY2BYmYpZva5mX1rZmvNbKqvPeSP/0n23e9jrzkPHzOLBDYAlwDFwFLgFufcOk8LCyIzKwCynHMhf6GUmZ0HHAZecc6d7Wv7LbDXOfeE78tDV+fcI17WGSgn2P9/Aw475/7Ty9oCzcx6Ab2cc8vNrCOwDLgWuIsQP/4n2feb8PPYq+fxvRHAJufcFufcUeANYLzHNUmAOOcWAHuPax4PzPQ9nkndX6qQdIL9DwvOuRLn3HLf40PAt0ASYXD8T7LvflN4fC8JKKr3ezFN/I/aijngYzNbZmaTvS7GAz2ccyVQ95cM6O5xPV6438xW+Ya1Qm7Y5nhmlgYMAb4mzI7/cfsOfh57hcf3rIG2cBvTG+2c
GwpcAUzxDW1I+PgTkAEMBkqA33laTYCZWQfgL8BDzrmDXtcTTA3su9/HXuHxvWIgpd7vycB2j2rxhHNuu+/nLuAd6obywslO35jwsbHhXR7XE1TOuZ3OuRrnXC3wZ0L4+JtZNHX/eL7unJvraw6L49/Qvjfl2Cs8vrcUyDSzdDNrA0wA3vO4pqAxs/a+CTTMrD1wKbDm5K8KOe8BE32PJwLvelhL0B37h9PnOkL0+JuZAS8B3zrn/qveUyF//E+070059jrbqh7f6WlPA5HAdOfc495WFDxm1oe63gZAFDArlPffzGYDF1C3FPVO4F+BvwJzgN5AIXCjcy4kJ5VPsP8XUDds4YAC4KfH5gBCiZmNAb4CVgO1vuZ/oW7sP6SP/0n2/Rb8PPYKDxER8ZuGrURExG8KDxER8ZvCQ0RE/KbwEBERvyk8RETEbwoPkSYys5p6q5CubM6VmM0srf6KtyItTZTXBYi0YuXOucFeFyHiBfU8RJqZ774ovzGzXN+fvr72VDOb71t8br6Z9fa19zCzd8zsG9+fbN9bRZrZn333XfjYzNp6tlMix1F4iDRd2+OGrW6u99xB59wI4A/UrVqA7/ErzrmBwOvAs772Z4EvnXODgKHAWl97JvBH59xZwH7ghoDujYgfdIW5SBOZ2WHnXIcG2guAC51zW3yL0O1wzsWb2W7qbsRT5Wsvcc4lmFkpkOycq6z3HmnAJ865TN/vjwDRzrn/CMKuiZySeh4igeFO8PhE2zSkst7jGjRHKS2IwkMkMG6u93Ox73EOdas1A9wGLPQ9ng/cB3W3QzazTsEqUqSp9E1GpOnamtnKer9/6Jw7drpujJl9Td0XtFt8bQ8C083sl0ApcLevfSowzcwmUdfDuI+6G/KItFia8xBpZr45jyzn3G6vaxEJFA1biYiI39TzEBERv6nnISIiflN4iIiI3xQeIiLiN4WHiIj4TeEhIiJ++29cAcOdaaykhwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training-loss history as recorded in loss_hist[\"train loss\"];\n",
    "# the x-axis assumes one entry per epoch.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(loss_hist[\"train loss\"])\n",
    "ax.set(xlabel=\"Epoch\", ylabel=\"Loss\", title=\"Training loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy%:  61.55 == 6155 / 10000\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained model on the test set and report top-1 accuracy.\n",
    "model.eval()  # put layers such as dropout/batchnorm into inference mode\n",
    "y_true_test = []\n",
    "y_pred_test = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for img, labels in testloader:\n",
    "        img = img.to(device)\n",
    "        # Labels stay on the CPU: they are only converted to Python ints below.\n",
    "\n",
    "        # The model returns a tuple; only the first element (per-class scores) is used.\n",
    "        # NOTE(review): the second element is presumably the distillation-head output - confirm.\n",
    "        preds, _ = model(img)\n",
    "\n",
    "        y_pred_test.extend(preds.argmax(dim=-1).tolist())\n",
    "        y_true_test.extend(labels.tolist())\n",
    "\n",
    "total = len(y_pred_test)\n",
    "total_correct = sum(pred == true for pred, true in zip(y_pred_test, y_true_test))\n",
    "# Guard against an empty test loader instead of raising ZeroDivisionError.\n",
    "accuracy = total_correct * 100 / total if total else 0.0\n",
    "\n",
    "print(\"Test Accuracy%: \", accuracy, \"==\", total_correct, \"/\", total)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}